{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# fastNLP中的 Vocabulary\n",
    "## 构建 Vocabulary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import Vocabulary\n",
    "\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(['复', '旦', '大', '学'])  # 加入新的字\n",
    "vocab.add_word('上海')  # `上海`会作为一个整体\n",
    "vocab.to_index('复')  # 应该会为3\n",
    "vocab.to_index('我')  # 会输出1，Vocabulary中默认pad的index为0, unk(没有找到的词)的index为1\n",
    "\n",
    "#  在构建target的Vocabulary时，词表中应该用不上pad和unk，可以通过以下的初始化\n",
    "vocab = Vocabulary(unknown=None, padding=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Vocabulary(['positive', 'negative']...)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab.add_word_lst(['positive', 'negative'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab.to_index('positive')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 没有设置 unk 的情况"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "word `neutral` not in vocabulary",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-4-c6d424040b45>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'neutral'\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# 会报错，因为没有unk这种情况\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/core/vocabulary.py\u001b[0m in \u001b[0;36mto_index\u001b[0;34m(self, w)\u001b[0m\n\u001b[1;32m    414\u001b[0m         \u001b[0;34m:\u001b[0m\u001b[0;32mreturn\u001b[0m \u001b[0mint\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mnumber\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    415\u001b[0m         \"\"\"\n\u001b[0;32m--> 416\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    417\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    418\u001b[0m     \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/core/vocabulary.py\u001b[0m in \u001b[0;36m_wrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m     42\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_word2idx\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrebuild\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild_vocab\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0m_wrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/core/vocabulary.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, w)\u001b[0m\n\u001b[1;32m    272\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_word2idx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munknown\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    273\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 274\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"word `{}` not in vocabulary\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    275\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    276\u001b[0m     \u001b[0;34m@\u001b[0m\u001b[0m_check_build_vocab\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: word `neutral` not in vocabulary"
     ]
    }
   ],
   "source": [
    "vocab.to_index('neutral')  # 会报错，因为没有unk这种情况"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 设置 unk 的情况"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0, '<unk>')"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from fastNLP import Vocabulary\n",
    "\n",
    "vocab = Vocabulary(unknown='<unk>', padding=None)\n",
    "vocab.add_word_lst(['positive', 'negative'])\n",
    "vocab.to_index('neutral'), vocab.to_word(vocab.to_index('neutral'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Vocabulary(['positive', 'negative']...)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------------------------------------------------+--------+\n",
      "| chars                                             | target |\n",
      "+---------------------------------------------------+--------+\n",
      "| [4, 2, 2, 5, 6, 7, 3]                             | 0      |\n",
      "| [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 3] | 1      |\n",
      "+---------------------------------------------------+--------+\n"
     ]
    }
   ],
   "source": [
    "from fastNLP import Vocabulary\n",
    "from fastNLP import DataSet\n",
    "\n",
    "dataset = DataSet({'chars': [\n",
    "                                ['今', '天', '天', '气', '很', '好', '。'],\n",
    "                                ['被', '这', '部', '电', '影', '浪', '费', '了', '两', '个', '小', '时', '。']\n",
    "                            ],\n",
    "                    'target': ['neutral', 'negative']\n",
    "})\n",
    "\n",
    "vocab = Vocabulary()\n",
    "vocab.from_dataset(dataset, field_name='chars')\n",
    "vocab.index_dataset(dataset, field_name='chars')\n",
    "\n",
    "target_vocab = Vocabulary(padding=None, unknown=None)\n",
    "target_vocab.from_dataset(dataset, field_name='target')\n",
    "target_vocab.index_dataset(dataset, field_name='target')\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Vocabulary(['今', '天', '心', '情', '很']...)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from fastNLP import Vocabulary\n",
    "from fastNLP import DataSet\n",
    "\n",
    "tr_data = DataSet({'chars': [\n",
    "                               ['今', '天', '心', '情', '很', '好', '。'],\n",
    "                               ['被', '这', '部', '电', '影', '浪', '费', '了', '两', '个', '小', '时', '。']\n",
    "                           ],\n",
    "                   'target': ['positive', 'negative']\n",
    "})\n",
    "dev_data = DataSet({'chars': [\n",
    "                               ['住', '宿', '条', '件', '还', '不', '错'],\n",
    "                               ['糟', '糕', '的', '天', '气', '，', '无', '法', '出', '行', '。']\n",
    "                           ],\n",
    "                   'target': ['positive', 'negative']\n",
    "})\n",
    "\n",
    "vocab = Vocabulary()\n",
    "#  将验证集或者测试集在建立词表是放入no_create_entry_dataset这个参数中。\n",
    "vocab.from_dataset(tr_data, field_name='chars', no_create_entry_dataset=[dev_data])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▎         | 2.31M/63.5M [00:00<00:02, 22.9MB/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "http://212.129.155.247/embedding/glove.6B.50d.zip not found in cache, downloading to /tmp/tmpvziobj_e\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 63.5M/63.5M [00:01<00:00, 41.3MB/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finish download from http://212.129.155.247/embedding/glove.6B.50d.zip\n",
      "Copy file to /remote-home/ynzheng/.fastNLP/embedding/glove.6B.50d\n",
      "Found 2 out of 6 words in the pre-training embedding.\n",
      "tensor([[ 0.9497,  0.3433,  0.8450, -0.8852, -0.7208, -0.2931, -0.7468,  0.6512,\n",
      "          0.4730, -0.7401,  0.1877, -0.3828, -0.5590,  0.4295, -0.2698, -0.4238,\n",
      "         -0.3124,  1.3423, -0.7857, -0.6302,  0.9182,  0.2113, -0.5744,  1.4549,\n",
      "          0.7546, -1.6165, -0.0085,  0.0029,  0.5130, -0.4745,  2.5306,  0.8594,\n",
      "         -0.3067,  0.0578,  0.6623,  0.2080,  0.6424, -0.5246, -0.0534,  1.1404,\n",
      "         -0.1370, -0.1836,  0.4546, -0.5096, -0.0255, -0.0286,  0.1805, -0.4483,\n",
      "          0.4053, -0.3682]], grad_fn=<EmbeddingBackward>)\n",
      "tensor([[ 0.1320, -0.2392,  0.1732, -0.2390, -0.0463,  0.0494,  0.0488, -0.0886,\n",
      "          0.0224, -0.1300,  0.0369,  0.1800,  0.0750, -0.0183,  0.2264,  0.1628,\n",
      "          0.1261, -0.1259,  0.1663, -0.1230, -0.1904, -0.0532,  0.1397, -0.0259,\n",
      "         -0.1799,  0.0226,  0.1858,  0.1981,  0.1338,  0.2394,  0.0248,  0.0203,\n",
      "         -0.1722, -0.1683, -0.1892,  0.0874,  0.0562, -0.0394,  0.0306, -0.1761,\n",
      "          0.1015, -0.0171,  0.1172,  0.1357,  0.1519, -0.0011,  0.1572,  0.1265,\n",
      "         -0.2391, -0.0258]], grad_fn=<EmbeddingBackward>)\n",
      "tensor([[ 0.1318, -0.2552, -0.0679,  0.2619, -0.2616,  0.2357,  0.1308, -0.0118,\n",
      "          1.7659,  0.2078,  0.2620, -0.1643, -0.8464,  0.0201,  0.0702,  0.3978,\n",
      "          0.1528, -0.2021, -1.6184, -0.5433, -0.1786,  0.5389,  0.4987, -0.1017,\n",
      "          0.6626, -1.7051,  0.0572, -0.3241, -0.6683,  0.2665,  2.8420,  0.2684,\n",
      "         -0.5954, -0.5004,  1.5199,  0.0396,  1.6659,  0.9976, -0.5597, -0.7049,\n",
      "         -0.0309, -0.2830, -0.1356,  0.6429,  0.4149,  1.2362,  0.7659,  0.9780,\n",
      "          0.5851, -0.3018]], grad_fn=<EmbeddingBackward>)\n",
      "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "         0., 0.]], grad_fn=<EmbeddingBackward>)\n",
      "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "         0., 0.]], grad_fn=<EmbeddingBackward>)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from fastNLP.embeddings import StaticEmbedding\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word('train')\n",
    "vocab.add_word('only_in_train')  # 仅在train出现，但肯定在预训练词表中不存在\n",
    "vocab.add_word('test', no_create_entry=True)  # 该词只在dev或test中出现\n",
    "vocab.add_word('only_in_test', no_create_entry=True)  # 这个词在预训练的词表中找不到\n",
    "\n",
    "embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')\n",
    "print(embed(torch.LongTensor([vocab.to_index('train')])))\n",
    "print(embed(torch.LongTensor([vocab.to_index('only_in_train')])))\n",
    "print(embed(torch.LongTensor([vocab.to_index('test')])))\n",
    "print(embed(torch.LongTensor([vocab.to_index('only_in_test')])))\n",
    "print(embed(torch.LongTensor([vocab.unknown_idx])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python Now",
   "language": "python",
   "name": "now"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
